Statistical and Machine Learning

Lab 6: Classification and Regression Trees (CART)

Tsai, Dai-Rong

Dataset

Sales of Child Car Seats

A simulated data set containing sales of child car seats at 400 different stores.

# Set random seed
set.seed(123)

# Packages
library(rpart) # for rpart
library(rpart.plot) # for rpart.plot

# Data
data(Carseats, package = "ISLR")
Carseats <- transform(Carseats, High = factor(ifelse(Sales <= 8, "No", "Yes")))
  • Response
    • Sales (continuous): Unit sales (in thousands) at each location
    • High (categorical): A factor with levels No and Yes to indicate whether the Sales variable exceeds 8
  • Predictors
    • CompPrice: Price charged by competitor at each location
    • Income: Community income level (in thousands of dollars)
    • Advertising: Local advertising budget for company at each location (in thousands of dollars)
    • Population: Population size in region (in thousands)
    • Price: Price company charges for car seats at each site
    • ShelveLoc: A factor with levels Bad, Good and Medium indicating the quality of the shelving location for the car seats at each site
    • Age: Average age of the local population
    • Education: Education level at each location
    • Urban: A factor with levels No and Yes to indicate whether the store is in an urban or rural location
    • US: A factor with levels No and Yes to indicate whether the store is in the US or not
dim(Carseats)
[1] 400  12
head(Carseats)
  Sales CompPrice Income Advertising Population Price ShelveLoc Age Education Urban  US High
1  9.50       138     73          11        276   120       Bad  42        17   Yes Yes  Yes
2 11.22       111     48          16        260    83      Good  65        10   Yes Yes  Yes
3 10.06       113     35          10        269    80    Medium  59        12   Yes Yes  Yes
4  7.40       117    100           4        466    97    Medium  55        14   Yes Yes   No
5  4.15       141     64           3        340   128       Bad  38        13   Yes  No   No
6 10.81       124    113          13        501    72       Bad  78        16    No Yes  Yes
table(Carseats$High)

 No Yes 
236 164 
hist(Carseats$Sales, breaks = 30); abline(v = 8, col = 2, lwd = 3)

str(Carseats)
'data.frame':   400 obs. of  12 variables:
 $ Sales      : num  9.5 11.22 10.06 7.4 4.15 ...
 $ CompPrice  : num  138 111 113 117 141 124 115 136 132 132 ...
 $ Income     : num  73 48 35 100 64 113 105 81 110 113 ...
 $ Advertising: num  11 16 10 4 3 13 0 15 0 0 ...
 $ Population : num  276 260 269 466 340 501 45 425 108 131 ...
 $ Price      : num  120 83 80 97 128 72 108 120 124 124 ...
 $ ShelveLoc  : Factor w/ 3 levels "Bad","Good","Medium": 1 2 3 3 1 1 3 2 3 3 ...
 $ Age        : num  42 65 59 55 38 78 71 67 76 76 ...
 $ Education  : num  17 10 12 14 13 16 15 10 10 17 ...
 $ Urban      : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 1 2 2 1 1 ...
 $ US         : Factor w/ 2 levels "No","Yes": 2 2 2 2 1 2 1 2 1 2 ...
 $ High       : Factor w/ 2 levels "No","Yes": 2 2 2 1 1 2 1 2 1 1 ...

Create Training/Testing Partitions

  • Split data into 70% training set and 30% test set
nr <- nrow(Carseats)
train.id <- sample(nr, nr * 0.7)

training <- Carseats[train.id, ]
testing <- Carseats[-train.id, ]
  • Check dimension
dim(training)
[1] 280  12
dim(testing)
[1] 120  12
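The simple random split above does not guarantee that the No/Yes mix of High is the same in both partitions. A stratified split fixes that; below is a sketch in base R, where stratified_split is a hypothetical helper (not part of the lab) and y reuses the class counts from table(Carseats$High) above:

```r
# Hypothetical helper: sample a fraction p of the row indices within each
# level of y, so the class proportions survive the split.
stratified_split <- function(y, p = 0.7) {
  unlist(lapply(split(seq_along(y), y), function(idx)
    sample(idx, round(length(idx) * p))))
}

set.seed(123)
y <- factor(rep(c("No", "Yes"), c(236, 164)))  # same counts as Carseats$High
train.id.s <- stratified_split(y, 0.7)
table(y[train.id.s])  # 165 No, 115 Yes -- the same 59/41 mix as the full data
```

In practice you would call stratified_split(Carseats$High) and index Carseats with the result, exactly as with train.id above.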

Classification Tree

Source: An Introduction to Recursive Partitioning Using the rpart Routines

ctree.mod <- rpart(High ~ . - Sales, data = training, method = "class",
                   control = rpart.control(cp = 0.001))

Arguments

  • method: the type of splitting rule to use.
    • "class": default when y is a factor.
    • "anova": default when y is numeric.
    • "exp": default when y is a survival object.
    • "poisson": default when y has 2 columns.
  • parms = list(...): a list of method-specific optional parameters. For classification, the list can contain any of:
    • prior: the vector of prior probabilities which must be positive and sum to 1.
    • loss: the loss matrix with zeros on the diagonal and positive off-diagonal elements.
    • split: the splitting index, "gini" or "information" (cross-entropy).
  • control = rpart.control(...): a list of options that control details of the rpart algorithm.
    • cp: (default: 0.01) the scaled cost-complexity parameter. Any split that does not decrease the overall lack-of-fit by a factor of cp is not attempted. \[ \begin{aligned} R_\alpha(T) &= R(T) + \alpha \cdot |T| \\ \xrightarrow{cp = \frac{\alpha}{R(T_1)}} R_{cp}(T) &= R(T) + cp \cdot R(T_1) \cdot |T| \end{aligned} \]

      where \(T_1\) is the tree with no splits, \(|T|\) is the number of splits for a tree \(T\), and \(R(\cdot)\) is the risk.

      • cp = 0: full model
      • cp = 1: null model / model with no splits
    • minsplit: (default: 20) the minimum number of observations that must exist in a node in order for a split to be attempted. This parameter can save computation time, since smaller nodes are almost always pruned away by cross-validation.

    • maxdepth: (default: 30) the maximum depth of the tree, with the root node counted as depth 0.

    • xval: (default: 10) the number of cross-validations.
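To see the parms and control arguments together, here is a sketch fitting a tree with the information (cross-entropy) splitting index and tightened control settings. The data frame d is a synthetic stand-in (the lab's actual call uses High ~ . - Sales on training), and all parameter values are purely illustrative:

```r
library(rpart)

# Synthetic stand-in data: two numeric predictors and a binary response.
set.seed(123)
d <- data.frame(x1 = rnorm(200), x2 = runif(200))
d$y <- factor(ifelse(d$x1 + d$x2 + rnorm(200, sd = 0.5) > 0.5, "Yes", "No"))

fit <- rpart(y ~ ., data = d, method = "class",
             parms = list(split = "information"),   # cross-entropy splits
             control = rpart.control(cp = 0.001,    # keep small improvements
                                     minsplit = 10, # split nodes of >= 10 obs
                                     maxdepth = 5,  # cap tree depth
                                     xval = 10))    # 10-fold CV for cptable
fit$control$maxdepth
```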

rpart.plot(ctree.mod)
rpart.plot(ctree.mod, extra = 2, box.palette = "BuOr", shadow.col = "gray", fallen.leaves = FALSE)

Variable Importance

ctree.mod$variable.importance
  ShelveLoc       Price Advertising         Age   CompPrice   Education      Income          US 
   23.61047    19.16682    13.32721    10.31405     8.89261     5.35242     2.96884     2.84405 
      Urban  Population 
    1.30643     0.75814 
barplot(rev(ctree.mod$variable.importance), horiz = TRUE, las = 1,
        cex.names = 0.7, col = "skyblue")
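The importances are on rpart's internal goodness-of-split scale, so rescaling them to percentages makes them easier to compare. A small sketch using the values printed above:

```r
# Importance values copied from ctree.mod$variable.importance above,
# rescaled so they sum to 100.
vi <- c(ShelveLoc = 23.61047, Price = 19.16682, Advertising = 13.32721,
        Age = 10.31405, CompPrice = 8.89261, Education = 5.35242,
        Income = 2.96884, US = 2.84405, Urban = 1.30643, Population = 0.75814)
round(100 * vi / sum(vi), 1)  # ShelveLoc alone accounts for about a quarter
```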

Tree Pruning

printcp(ctree.mod)

Classification tree:
rpart(formula = High ~ . - Sales, data = training, method = "class", 
    control = rpart.control(cp = 0.001))

Variables actually used in tree construction:
[1] Advertising Age         CompPrice   Education   Price       ShelveLoc  

Root node error: 115/280 = 0.411

n= 280 

      CP nsplit rel error xerror   xstd
1 0.2783      0     1.000  1.000 0.0716
2 0.0696      1     0.722  0.722 0.0665
3 0.0522      3     0.583  0.783 0.0680
4 0.0435      4     0.530  0.800 0.0683
5 0.0304      5     0.487  0.783 0.0680
6 0.0261      7     0.426  0.678 0.0652
7 0.0130      8     0.400  0.739 0.0669
8 0.0010     10     0.374  0.713 0.0662
plotcp(ctree.mod)

  • Horizontal line: 1SE above the minimum of the curve.
  • A good choice of cp for pruning is often the leftmost value for which the mean lies below the horizontal line.
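Note that bestcp below follows the minimum-xerror rule rather than the 1-SE rule just described. The 1-SE choice can be sketched directly from the cptable (values copied from the printcp output above):

```r
# cptable columns copied from printcp(ctree.mod) above.
cptab <- cbind(CP     = c(0.2783, 0.0696, 0.0522, 0.0435,
                          0.0304, 0.0261, 0.0130, 0.0010),
               xerror = c(1.000, 0.722, 0.783, 0.800,
                          0.783, 0.678, 0.739, 0.713),
               xstd   = c(0.0716, 0.0665, 0.0680, 0.0683,
                          0.0680, 0.0652, 0.0669, 0.0662))

# Threshold: minimum xerror plus one SE at the minimum.
i.min <- which.min(cptab[, "xerror"])
thr <- cptab[i.min, "xerror"] + cptab[i.min, "xstd"]

# 1-SE choice: the simplest tree (leftmost row) below the threshold.
cp.1se <- cptab[which(cptab[, "xerror"] <= thr)[1], "CP"]
cp.1se  # 0.0696 -- prunes harder than the minimum-xerror choice of 0.0261
```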
bestcp <- ctree.mod$cptable[which.min(ctree.mod$cptable[, "xerror"]), "CP"]
bestcp
[1] 0.026087
ctree.mod.pruned <- prune(ctree.mod, cp = bestcp)
rpart.plot(ctree.mod.pruned, box.palette = "BuOr", shadow.col = "gray")

Prediction

  • See ?predict.rpart
  • type: "vector", "prob", "class", "matrix"
  • Unpruned tree
pred.ctree <- predict(ctree.mod, testing, type = "class")
table(true = testing$High, pred = pred.ctree)
     pred
true  No Yes
  No  56  15
  Yes 19  30
# Accuracy
mean(testing$High == pred.ctree)
[1] 0.71667
  • Pruned tree
pred.ctree.pruned <- predict(ctree.mod.pruned, testing, type = "class")
table(true = testing$High, pred = pred.ctree.pruned)
     pred
true  No Yes
  No  59  12
  Yes 19  30
# Accuracy
mean(testing$High == pred.ctree.pruned)
[1] 0.74167
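Accuracy alone hides the asymmetry between the two classes; sensitivity and specificity can be read off the confusion matrix. A sketch with the pruned-tree counts copied from the table above:

```r
# Confusion matrix of the pruned tree (counts copied from the output above).
cm <- matrix(c(59, 12,
               19, 30), nrow = 2, byrow = TRUE,
             dimnames = list(true = c("No", "Yes"), pred = c("No", "Yes")))

accuracy    <- sum(diag(cm)) / sum(cm)              # (59 + 30) / 120
sensitivity <- cm["Yes", "Yes"] / sum(cm["Yes", ])  # 30 / 49: true "Yes" rate
specificity <- cm["No", "No"]   / sum(cm["No", ])   # 59 / 71: true "No" rate
round(c(accuracy = accuracy, sensitivity = sensitivity,
        specificity = specificity), 3)
```

The pruned tree is noticeably better at recognizing low-sales stores (specificity) than high-sales ones (sensitivity).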

Regression Tree

regtree.mod <- rpart(Sales ~ . - High, data = training, method = "anova",
                     control = rpart.control(cp = 0.001))
printcp(regtree.mod)

Regression tree:
rpart(formula = Sales ~ . - High, data = training, method = "anova", 
    control = rpart.control(cp = 0.001))

Variables actually used in tree construction:
[1] Advertising Age         CompPrice   Price       ShelveLoc  

Root node error: 2227/280 = 7.95

n= 280 

        CP nsplit rel error xerror   xstd
1  0.25462      0     1.000  1.011 0.0848
2  0.09215      1     0.745  0.759 0.0593
3  0.07090      2     0.653  0.742 0.0591
4  0.04325      3     0.582  0.732 0.0576
5  0.03605      4     0.539  0.690 0.0536
6  0.03227      5     0.503  0.676 0.0542
7  0.02429      7     0.438  0.636 0.0511
8  0.01748      8     0.414  0.612 0.0498
9  0.01592      9     0.397  0.631 0.0504
10 0.01563     10     0.381  0.631 0.0506
11 0.01413     11     0.365  0.633 0.0535
12 0.01354     12     0.351  0.625 0.0534
13 0.01265     14     0.324  0.630 0.0546
14 0.01046     15     0.311  0.630 0.0542
15 0.00925     16     0.301  0.625 0.0536
16 0.00888     17     0.292  0.619 0.0538
17 0.00778     18     0.283  0.607 0.0508
18 0.00695     20     0.267  0.596 0.0506
19 0.00680     21     0.260  0.595 0.0504
20 0.00530     22     0.253  0.585 0.0500
 [ reached getOption("max.print") -- omitted 3 rows ]
plotcp(regtree.mod)

bestcp <- regtree.mod$cptable[which.min(regtree.mod$cptable[, "xerror"]), "CP"]
bestcp
[1] 0.001
  • Note: here the minimum cross-validated error occurs at the smallest cp in the grid (0.001, the value the tree was grown with), so prune() at bestcp leaves the tree unchanged; the 1-SE rule would pick a larger cp and a smaller tree.
regtree.mod.pruned <- prune(regtree.mod, cp = bestcp)
rpart.plot(regtree.mod.pruned)

Prediction

pred.regtree.pruned <- predict(regtree.mod.pruned, testing)

# MSE
mean((testing$Sales - pred.regtree.pruned)^2)
[1] 3.5192
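For scale: printcp reported a root-node (no-split) training MSE of 2227/280 ≈ 7.95, so the pruned tree roughly halves the squared error relative to always predicting the mean, and the RMSE puts the error back in units of thousands of seats. A quick check with the numbers above:

```r
# Numbers copied from the output above.
null.mse <- 2227 / 280   # root-node error: predict the mean everywhere
test.mse <- 3.5192       # pruned regression tree, test set

sqrt(test.mse)           # RMSE, in the same units as Sales (thousands)
1 - test.mse / null.mse  # relative reduction in squared error
```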